{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8ad0609b",
   "metadata": {},
   "source": [
    "## Comp2-1 量表类型数据\n",
    "\n",
    "主要适配于临床数据的建模和刻画。典型的应用场景探究rad_score最终临床诊断的作用。\n",
    "\n",
    "## Onekey步骤\n",
    "\n",
    "1. 数据校验，检查数据格式是否正确。\n",
    "3. 查看一些统计信息，检查数据时候存在异常点。\n",
    "4. 正则化，将数据变化到服从 N~(0, 1)。\n",
    "5. 通过相关系数，例如spearman、person等筛选出特征。\n",
    "6. 构建训练集和测试集，这里使用的是随机划分，正常多中心验证，需要大家根据自己的场景构建两份数据。\n",
    "7. 通过Lasso筛选特征，选取其中的非0项作为后续模型的特征。\n",
    "8. 使用机器学习算法，例如LR、SVM、RF等进行任务学习。\n",
    "9. 模型结果可视化，例如AUC、ROC曲线，混淆矩阵等。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10505b86",
   "metadata": {},
   "source": [
    "# 筛选临床特征\n",
    "\n",
    "需要根据自己的情况，筛选特征，一般情况筛选pvalue<0.05的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0769d859",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import namedtuple\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "from onekey_algo import get_param_in_cwd\n",
    "import onekey_algo.custom.components as okcomp\n",
    "from onekey_algo.custom.components.comp1 import fillna\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.rcParams['figure.dpi'] = 300"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b09b462",
   "metadata": {},
   "source": [
    "## 一、数据校验\n",
    "首先需要检查诊断数据，如果显示`检查通过！`择可以正常运行之后的，否则请根据提示调整数据。\n",
    "\n",
    "数据文件中的数据都是数值类型，或者可以映射成数值类型，这里的`label`某些情况下可能是非数值的，需要自定义数值函数。\n",
    "\n",
    "**注意：在使用树模型时，可以存在缺失，但是线性模型不允许缺失，请自行根据需要填充缺省值**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a30ee99b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取数据，B超诊断阳性=1，bc_data.csv是要读取的数据。\n",
    "task_type = 'Clinical'\n",
    "# 筛不出来临床特征\n",
    "data_file = r'data/clinical.csv'\n",
    "# 筛选出来\n",
    "data_file = r'data/clinic_sel.csv'\n",
    "labels = [get_param_in_cwd('label_col', 'label')]\n",
    "group_info = get_param_in_cwd('dataset_col', 'group')\n",
    "featrues_not_use = ['ID']\n",
    "\n",
    "structed_data = pd.read_csv(data_file, header=0)\n",
    "structed_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a9d023",
   "metadata": {},
   "source": [
    "### 特征维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb757659",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 删掉ID这一列。\n",
    "ids = structed_data['ID']\n",
    "structed_data = structed_data.drop(featrues_not_use, axis=1)\n",
    "structed_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17517726",
   "metadata": {},
   "source": [
    "## 二、数据统计\n",
    "\n",
    "1. count，统计样本个数。\n",
    "2. mean、std, 对应特征的均值、方差\n",
    "3. min, 25%, 50%, 75%, max，对应特征的最小值，25,50,75分位数，最大值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b25dbc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "structed_data.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a480b2f3",
   "metadata": {},
   "source": [
    "## 三、正则化\n",
    "\n",
    "临床特征可选是否正则化，默认不进行\n",
    "\n",
    "`normalize_df` 为onekey中正则化的API，将数据变化到0均值1方差。正则化的方法为\n",
    "\n",
    "$column = \\frac{column - mean}{std}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1ad46b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from onekey_algo.custom.components.comp1 import normalize_df\n",
    "# data = normalize_df(structed_data, not_norm=labels + ['group'])\n",
    "# data = data.dropna(axis=1)\n",
    "# data.describe()\n",
    "data = structed_data\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3eba2a84",
   "metadata": {},
   "source": [
    "## 四、相关系数\n",
    "\n",
    "计算相关系数的方法有3种可供选择\n",
    "1. pearson （皮尔逊相关系数）: standard correlation coefficient\n",
    "\n",
    "2. kendall (肯德尔相关性系数) : Kendall Tau correlation coefficient\n",
    "\n",
    "3. spearman (斯皮尔曼相关性系数): Spearman rank correlation\n",
    "\n",
    "三种相关系数参考：https://blog.csdn.net/zmqsdu9001/article/details/82840332"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4897b00c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 如果需要选择相关系数使用对应的相关系数即可。\n",
    "# pearson_corr = data.corr('pearson')\n",
    "# kendall_corr = data.corr('kendall')\n",
    "spearman_corr = data.corr('spearman')\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from onekey_algo.custom.components.comp1 import draw_matrix\n",
    "plt.figure(figsize=(10.0, 8.0))\n",
    "\n",
    "# 选择可视化的相关系数\n",
    "draw_matrix(spearman_corr, annot=True, cmap='YlGnBu', cbar=False)\n",
    "plt.savefig(f'img/Clinic_feature_corr.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f40b311",
   "metadata": {},
   "source": [
    "## 五、构建数据\n",
    "\n",
    "将样本的训练数据X与监督信息y分离出来，并且对训练数据进行划分，一般的划分原则为80%-20%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10d0dc73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import onekey_algo.custom.components as okcomp\n",
    "from collections import OrderedDict\n",
    "n_classes = 2\n",
    "sel_data = data\n",
    "train_data = sel_data[(sel_data[group_info] == 'train')]\n",
    "train_ids = ids[train_data.index]\n",
    "train_data = train_data.reset_index()\n",
    "train_data = train_data.drop('index', axis=1)\n",
    "y_data = train_data[labels]\n",
    "X_data = train_data.drop(labels + [group_info], axis=1)\n",
    "\n",
    "subsets = [s for s in get_param_in_cwd('subsets') if s != 'train']\n",
    "val_datasets = OrderedDict()\n",
    "for subset in subsets:\n",
    "    val_data = sel_data[sel_data[group_info] == subset]\n",
    "    val_ids = ids[val_data.index]\n",
    "    val_data = val_data.reset_index()\n",
    "    val_data = val_data.drop('index', axis=1)\n",
    "    y_val_data = val_data[labels]\n",
    "    X_val_data = val_data.drop(labels + [group_info], axis=1)\n",
    "    val_datasets[subset] = [X_val_data, y_val_data, val_ids]\n",
    "\n",
    "y_all_data = sel_data[labels]\n",
    "X_all_data = sel_data.drop(labels + [group_info], axis=1)\n",
    "\n",
    "column_names = X_data.columns\n",
    "print(f\"训练集样本数：{X_data.shape}，\", '，'.join([f\"{subset}样本数：{d_[0].shape}\" for subset, d_ in val_datasets.items()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "282e4f7f",
   "metadata": {},
   "source": [
    "## 六、模型筛选\n",
    "\n",
    "根据筛选出来的数据，做模型的初步选择。当前主要使用到的是Onekey中的\n",
    "\n",
    "1. SVM，支持向量机，引用参考。\n",
    "2. KNN，K紧邻，引用参考。\n",
    "3. Decision Tree，决策树，引用参考。\n",
    "4. Random Forests, 随机森林，引用参考。\n",
    "5. XGBoost, bosting方法。引用参考。\n",
    "6. LightGBM, bosting方法，引用参考。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583daa3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = get_param_in_cwd('ml_models') or ['SVM', 'KNN', 'RandomForest', 'ExtraTrees', 'XGBoost', 'LightGBM', 'MLP', 'LR']\n",
    "models = okcomp.comp1.create_clf_model(model_names)\n",
    "model_names = list(models.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7123d8a2",
   "metadata": {},
   "source": [
    "### 交叉验证\n",
    "\n",
    "`n_trails`指定随机次数，每次采用的是80%训练，随机20%进行测试，找到最好的模型，以及对应的最好的数据划分。\n",
    "\n",
    "这里的数据并没有使用前面`Lasso`筛选出来的特征进行训练，理论来说，特征筛选仅对线性模型有一定作用，例如`SVM`、`LR`，但是对树模型没什么作用，例如`DecisionTree`、`Random`这些。所以默认不筛选。\n",
    "\n",
    "```python\n",
    "def get_bst_split(X_data: pd.DataFrame, y_data: pd.DataFrame,\n",
    "            models: dict, test_size=0.2, metric_fn=accuracy_score, n_trails=10,\n",
    "            cv: bool = False, shuffle: bool = False, metric_cut_off: float = None, random_state=None):\n",
    "    \"\"\"\n",
    "    寻找数据集中最好的数据划分。\n",
    "    Args:\n",
    "        X_data: 训练数据\n",
    "        y_data: 监督数据\n",
    "        models: 模型名称，Dict类型、\n",
    "        test_size: 测试集比例\n",
    "        metric_fn: 评价模型好坏的函数，默认准确率，可选roc_auc_score。\n",
    "        n_trails: 尝试多少次寻找最佳数据集划分。\n",
    "        cv: 是否是交叉验证，默认是False，当为True时，n_trails为交叉验证的n_fold\n",
    "        shuffle: 是否进行随机打乱\n",
    "        metric_cut_off: 当metric_fn的值达到多少时进行截断。\n",
    "        random_state: 随机种子\n",
    "\n",
    "    Returns: {'max_idx': max_idx, \"max_model\": max_model, \"max_metric\": max_metric, \"results\": results}\n",
    "\n",
    "    \"\"\"\n",
    "```\n",
    "\n",
    "**注意：这里采用了【挑数据】，如果想要严谨，请修改`n_trails=1`。**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c747bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "\n",
    "# 随机使用n_trails次数据划分，找到最好的一次划分方法，并且保存在results中。\n",
    "results = okcomp.comp1.get_bst_split(X_data, y_data, models, test_size=0.2, metric_fn=roc_auc_score, n_trails=5, cv=False, \n",
    "                                     random_state=75)\n",
    "# _, (X_train_sel, X_test_sel, y_train_sel, y_test_sel) = results['results'][results['max_idx']]\n",
    "trails, _ = zip(*results['results'])\n",
    "cv_results = pd.DataFrame(trails, columns=model_names)\n",
    "# 可视化每个模型在不同的数据划分中的效果。\n",
    "sns.barplot(data=cv_results)\n",
    "plt.ylabel('AUC %')\n",
    "plt.xlabel('Model Nmae')\n",
    "# plt.xticks(rotation=90)\n",
    "plt.ylim(0.5,)\n",
    "plt.savefig(f'img/{task_type}_model_cv.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1555a87f",
   "metadata": {},
   "source": [
    "### 模型筛选\n",
    "\n",
    "使用最好的数据划分，进行后续的模型研究。\n",
    "\n",
    "**注意**: 一般情况下论文使用的是随机划分的数据，但也有些论文使用【刻意】筛选的数据划分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc872d56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "import os\n",
    "from onekey_algo.custom.components.comp1 import plot_feature_importance, smote_resample\n",
    "from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier\n",
    "from xgboost import XGBClassifier\n",
    "from lightgbm import LGBMClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "targets = []\n",
    "os.makedirs('models', exist_ok=True)\n",
    "for l in labels:\n",
    "    new_models = okcomp.comp1.create_clf_model(model_names)\n",
    "#     new_models['LR'] = LogisticRegression(penalty='l2', max_iter=8, C=0.8)\n",
    "#     new_models['LightGBM'] = LGBMClassifier(n_estimators=1, max_depth=2, random_state=0)\n",
    "#     new_models['RandomForest'] = RandomForestClassifier(n_estimators=7, max_depth=4, min_samples_split=2, random_state=0)\n",
    "#     new_models['ExtraTrees'] = ExtraTreesClassifier(n_estimators=8, max_depth=2, min_samples_split=2, random_state=0)\n",
    "#     new_models['XGBoost'] = XGBClassifier(n_estimators=1, objective='binary:logistic', max_depth=2, min_child_weight=.2,\n",
    "#                                               use_label_encoder=False, eval_metric='error')\n",
    "    model_names = list(new_models.keys())\n",
    "    new_models = list(new_models.values())\n",
    "    for mn, m in zip(model_names, new_models):        \n",
    "        X_train_sel, y_train_sel = X_data, y_data\n",
    "        if get_param_in_cwd('use_smote', False):\n",
    "            X_train_sel, y_train_sel = smote_resample(X_train_sel, y_train_sel)\n",
    "        m.fit(X_train_sel, y_train_sel[l])\n",
    "        # 保存训练的模型\n",
    "#         joblib.dump(m, f'models/{task_type}_{mn}_{l}.pkl') \n",
    "        # 输出模型特征重要性，只针对高级树模型有用\n",
    "#         plot_feature_importance(m, selected_features[0], save_dir='img', prefix=f\"{task_type}_\")\n",
    "    targets.append(new_models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44b6dbb9",
   "metadata": {},
   "source": [
    "## 七、预测结果\n",
    "\n",
    "* predictions，二维数据，每个label对应的每个模型的预测结果。\n",
    "* pred_scores，二维数据，每个label对应的每个模型的预测概率值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae5a1cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from onekey_algo.custom.components.delong import calc_95_CI\n",
    "from onekey_algo.custom.components.metrics import analysis_pred_binary\n",
    "\n",
    "metric = []\n",
    "pred_sel_idx = []\n",
    "X_train_sel, y_train_sel = X_data, y_data\n",
    "predictions = [[(model.predict(X_train_sel), \n",
    "                 [(model.predict(X_val_sel), y_val_sel) for X_val_sel, y_val_sel, _ in val_datasets.values()])  \n",
    "                for model in target] for label, target in zip(labels, targets)]\n",
    "pred_scores = [[(model.predict_proba(X_train_sel), \n",
    "                 [(model.predict_proba(X_val_sel), y_val_sel) for X_val_sel, y_val_sel, _ in val_datasets.values()]) \n",
    "                for model in target] for label, target in zip(labels, targets)]\n",
    "\n",
    "for label, prediction, scores in zip(labels, predictions, pred_scores):\n",
    "    for mname, (train_pred, val_preds), (train_score, val_scores) in zip(model_names, prediction, scores):\n",
    "        # 计算训练集指数\n",
    "        acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y_train_sel[label], \n",
    "                                                                                              train_score[:, 1])\n",
    "        ci = f\"{ci[0]:.3f} - {ci[1]:.3f}\"\n",
    "        metric.append((mname, acc, auc, ci, tpr, tnr, ppv, npv, thres, f\"train\"))\n",
    "        for subset, (val_score, y_val_sel) in zip(val_datasets.keys(), val_scores):\n",
    "            acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = analysis_pred_binary(y_val_sel[label], \n",
    "                                                                                                  val_score[:, 1])\n",
    "            ci = f\"{ci[0]:.3f} - {ci[1]:.3f}\"\n",
    "            metric.append((mname, acc, auc, ci, tpr, tnr, ppv, npv, thres, subset))\n",
    "metric = pd.DataFrame(metric, index=None, columns=['model_name', 'Accuracy', 'AUC', '95% CI',\n",
    "                                                   'Sensitivity', 'Specificity', 'PPV', 'NPV', 'Threshold', 'Cohort'])\n",
    "metric[['model_name', 'Accuracy', 'AUC', '95% CI', 'Sensitivity', 'Specificity', 'PPV', 'NPV', 'Cohort']].to_csv('results/Clinic_metrics.csv',\n",
    "                                                                                                            index=False)\n",
    "metric"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8222fe3c",
   "metadata": {},
   "source": [
    "### 绘制曲线\n",
    "\n",
    "绘制的不同模型的准确率柱状图和折线图曲线。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18573e1",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.subplot(211)\n",
    "sns.barplot(x='model_name', y='Accuracy', data=metric, hue='Cohort')\n",
    "plt.subplot(212)\n",
    "sns.lineplot(x='model_name', y='Accuracy', data=metric, hue='Cohort')\n",
    "plt.savefig(f'img/{task_type}_model_acc.svg', bbox_inches = 'tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76c84d6b",
   "metadata": {},
   "source": [
    "## 绘制ROC曲线\n",
    "确定最好的模型，并且绘制曲线。\n",
    "\n",
    "```python\n",
    "def draw_roc(y_test, y_score, title='ROC', labels=None):\n",
    "```\n",
    "\n",
    "`sel_model = ['SVM', 'KNN']`参数为想要绘制的模型对应的参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60526844",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_model = model_names\n",
    "\n",
    "for sm in sel_model:\n",
    "    if sm in model_names:\n",
    "        sel_model_idx = model_names.index(sm)\n",
    "    \n",
    "        # Plot all ROC curves\n",
    "        plt.figure(figsize=(10, 10))\n",
    "        for pred_score, label in zip(pred_scores, labels):\n",
    "            ys = [np.array(y_train_sel[label])] + [np.array(y_val_sel[label]) for _, y_val_sel, _ in val_datasets.values()]\n",
    "            ps = [pred_score[sel_model_idx][0]] + [p_[0] for p_ in pred_score[sel_model_idx][1]]\n",
    "            okcomp.comp1.draw_roc(ys, ps, \n",
    "                                  labels=['train'] + list(val_datasets.keys()), title=f\"Model: {sm}\")\n",
    "            plt.savefig(f'img/{task_type}_model_{sm}_roc.svg', bbox_inches = 'tight')\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c8732da",
   "metadata": {},
   "source": [
    "### 汇总所有模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8054e9a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_model = model_names\n",
    "\n",
    "for pred_score, label in zip(pred_scores, labels):\n",
    "    pred_val_scores = []\n",
    "    for sm in sel_model:\n",
    "        if sm in model_names:\n",
    "            sel_model_idx = model_names.index(sm)\n",
    "            pred_val_scores.append((pred_score[sel_model_idx][0], y_train_sel))\n",
    "    p_, l_ = zip(*pred_val_scores)\n",
    "    okcomp.comp1.draw_roc(l_, p_, labels=sel_model, title=f\"Cohort train AUC\")\n",
    "    plt.savefig(f'img/{task_type}_model_train_roc.svg', bbox_inches = 'tight')\n",
    "    plt.show()\n",
    "    for sel_subset_idx, subset in enumerate(val_datasets.keys()):\n",
    "        pred_val_scores = []\n",
    "        for sm in sel_model:\n",
    "            if sm in model_names:\n",
    "                sel_model_idx = model_names.index(sm)\n",
    "                pred_val_scores.append(pred_score[sel_model_idx][1][sel_subset_idx])\n",
    "        p_, l_ = zip(*pred_val_scores)\n",
    "        okcomp.comp1.draw_roc(l_, p_, labels=sel_model, title=f\"Cohort {subset} AUC\")\n",
    "        plt.savefig(f'img/{task_type}_model_{subset}_roc.svg', bbox_inches = 'tight')\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb20d7bd",
   "metadata": {},
   "source": [
    "### DCA决策曲线"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08820a8a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import plot_DCA, draw_calibration\n",
    "# 设置绘制参数\n",
    "sel_model = model_names\n",
    "for sm in sel_model:\n",
    "    if sm in model_names:\n",
    "        sel_model_idx = model_names.index(sm)\n",
    "        for idx, label in enumerate(labels):\n",
    "            try:\n",
    "                for sel_subset_idx, subset in enumerate(val_datasets.keys()):\n",
    "                    p_ = pred_scores[idx][sel_model_idx][0]\n",
    "                    plot_DCA([p_[:, 0]], y_train_sel[label], title=f'Model for DCA', labels=sm, remap=True, y_min=-0.15, )\n",
    "                    plt.savefig(f'img/{task_type}_model_train_{sm}_dca.svg', bbox_inches = 'tight')\n",
    "                    plt.show()\n",
    "                    draw_calibration(pred_scores=[p_[:, 0]], n_bins=5, remap=True,\n",
    "                                     y_test=[y_train_sel[label]], model_names=[sm])\n",
    "                    plt.savefig(f'img/{task_type}_model_train_{sm}_cali.svg', bbox_inches = 'tight')\n",
    "                    plt.show()\n",
    "                for sel_subset_idx, subset in enumerate(val_datasets.keys()):\n",
    "                    p_, l_ = pred_scores[idx][sel_model_idx][1][sel_subset_idx]\n",
    "                    plot_DCA([p_[:, 0]], l_[label], title=f'Model for DCA', labels=sm, remap=True, y_min=-0.15, )\n",
    "                    plt.savefig(f'img/{task_type}_model_{subset}_{sm}_dca.svg', bbox_inches = 'tight')\n",
    "                    plt.show()\n",
    "                    draw_calibration(pred_scores=[p_[:, 0]], n_bins=5, remap=True,\n",
    "                                     y_test=[l_[label]], model_names=[sm])\n",
    "                    plt.savefig(f'img/{task_type}_model_{subset}_{sm}_cali.svg', bbox_inches = 'tight')\n",
    "                    plt.show()\n",
    "            except:\n",
    "                pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb5d0e3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "os.makedirs('results', exist_ok=True)\n",
    "sel_model = sel_model\n",
    "\n",
    "all_log = []\n",
    "for idx, label in enumerate(labels):\n",
    "    for sm in sel_model:\n",
    "        if sm in model_names:\n",
    "            sel_model_idx = model_names.index(sm)\n",
    "            target = targets[idx][sel_model_idx]\n",
    "            # 预测训练集和测试集数据。\n",
    "            train_indexes = np.reshape(np.array(train_ids), (-1, 1)).astype(str)\n",
    "            # 保存预测的训练集和测试集结果\n",
    "            y_train_pred_scores = target.predict_proba(X_data)\n",
    "            columns = ['ID'] + [f\"{label}-{i}\"for i in range(y_train_pred_scores.shape[1])]\n",
    "            result_train = pd.DataFrame(np.concatenate([train_indexes, y_train_pred_scores], axis=1), columns=columns)\n",
    "            result_train.to_csv(f'results/{task_type}_{sm}_train.csv', index=False)\n",
    "            result_train['model'] = sm\n",
    "            result_train['pred_score'] = list(map(lambda x: max(x), \n",
    "                                                  np.array(result_train[['label-0', 'label-1']])))\n",
    "            result_train['pred_label'] = list(map(lambda x: 0 if x[0] > x[1] else 1, \n",
    "                                                  np.array(result_train[['label-0', 'label-1']])))\n",
    "            all_log.append(result_train)\n",
    "            for subset, (X_val_sel, y_val_sel, val_ids) in val_datasets.items():\n",
    "                val_indexes = np.reshape(np.array(val_ids), (-1, 1)).astype(str)\n",
    "                y_val_pred_scores = target.predict_proba(X_val_sel)\n",
    "                result_val = pd.DataFrame(np.concatenate([val_indexes, y_val_pred_scores], axis=1), columns=columns)\n",
    "                result_val.to_csv(f'results/{task_type}_{sm}_{subset}.csv', index=False)\n",
    "                result_val['model'] = sm\n",
    "                result_val['pred_score'] = list(map(lambda x: max(x), np.array(result_val[['label-0', 'label-1']])))\n",
    "                result_val['pred_label'] = list(map(lambda x: 0 if x[0] > x[1] else 1,  \n",
    "                                                    np.array(result_val[['label-0', 'label-1']])))\n",
    "                all_log.append(result_val)\n",
    "all_log = pd.concat(all_log, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "559f98c2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
